{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# train & valid & test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import random\n",
    "import os\n",
    "from sklearn.model_selection import KFold\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#set hyper param\n",
    "kfolds=5\n",
    "train_ratio=0.8\n",
    "seed=1\n",
    "no_cuda=False\n",
    "#set seed\n",
    "random.seed(seed)\n",
    "os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def GIP_sim(matrix):\n",
    "    matrix=matrix.float()\n",
    "    fz=(matrix*matrix).sum(dim=1,keepdims=True)+(matrix*matrix).sum(dim=1,keepdims=True).T-2*matrix@matrix.T\n",
    "    fm=1/torch.diag(matrix@matrix.T).mean()\n",
    "    return torch.exp(-1*fz*fm)\n",
    "def Functional_sim(ass,sim,device,batch=1): #(a,b) #(b,b)\n",
    "    s1=ass.shape[0]      #a\n",
    "    ass=ass.to(device)\n",
    "    sim=sim.to(device)\n",
    "    sim_m=torch.zeros(s1,s1).to(device)  #(a,a)\n",
    "    iter_comb=torch.tensor(list(itertools.combinations(range(s1),2))).long().to(device)\n",
    "    for i in range(iter_comb.shape[0]//batch):\n",
    "        idx1,idx2=iter_comb[i*batch:(i+1)*batch,0],iter_comb[i*batch:(i+1)*batch,1]\n",
    "        m1=ass[idx1,:]\n",
    "        m2=ass[idx2,:]\n",
    "        sim1=m1[:,:,None]*sim*m2[:,None,:]  # (batch,b,b)\n",
    "        sim_m[idx1,idx2]=(sim1.max(dim=1)[0].sum(dim=-1)+sim1.max(dim=2)[0].sum(dim=-1))/(m1.sum(dim=-1)+m2.sum(dim=-1))\n",
    "    if iter_comb.shape[0]%batch!=0:\n",
    "        idx1,idx2=iter_comb[(i+1)*batch:,0],iter_comb[(i+1)*batch:,1]\n",
    "        m1=ass[idx1,:]\n",
    "        m2=ass[idx2,:]\n",
    "        sim1=m1[:,:,None]*sim*m2[:,None,:]  # (batch,b,b)\n",
    "        sim_m[idx1,idx2]=(sim1.max(dim=1)[0].sum(dim=-1)+sim1.max(dim=2)[0].sum(dim=-1))/(m1.sum(dim=-1)+m2.sum(dim=-1))\n",
    "    sim_m=torch.where(torch.isinf(sim_m),torch.zeros_like(sim_m),sim_m)\n",
    "    sim_m=torch.where(torch.isnan(sim_m),torch.zeros_like(sim_m),sim_m)\n",
    "    return (sim_m+sim_m.T+torch.eye(s1).to(device)).cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_set,test_set,common_set={},{},{}\n",
    "device=torch.device(\"cuda\" if (torch.cuda.is_available() and not no_cuda) else \"cpu\")\n",
    "#load data\n",
    "md=np.load('../miRNA_disease.npy')\n",
    "mm=np.load('../miRNA_miRNA.npy')\n",
    "ml=np.load('../miRNA_lncRNA.npy')\n",
    "dd=np.load('../disease_disease.npy')\n",
    "dl=np.load('../disease_lncRNA.npy')\n",
    "ll=np.load('../lncRNA_lncRNA.npy')\n",
    "\n",
    "common_set['md']=torch.tensor(md).long()\n",
    "common_set['mm_seq']=torch.tensor(mm).float()\n",
    "common_set['ml']=torch.tensor(ml).long()\n",
    "common_set['dd_sem']=torch.tensor(dd).float()\n",
    "common_set['dl']=torch.tensor(dl).long()\n",
    "common_set['ll_seq']=torch.tensor(ll).float()\n",
    "\n",
    "common_set['mm_mlG']=GIP_sim(common_set['ml']).float()\n",
    "common_set['dd_dlG']=GIP_sim(common_set['dl']).float()\n",
    "common_set['ll_lmG']=GIP_sim(common_set['ml'].T).float()\n",
    "common_set['ll_ldG']=GIP_sim(common_set['dl'].T).float()\n",
    "\n",
    "common_set['mm_mlF']=Functional_sim(common_set['ml'],common_set['ll_seq'],device,128).float()\n",
    "common_set['dd_dlF']=Functional_sim(common_set['dl'],common_set['ll_seq'],device,128).float()\n",
    "common_set['ll_ldF']=Functional_sim(common_set['dl'].T,common_set['dd_sem'],device,64).float()\n",
    "common_set['ll_lmF']=Functional_sim(common_set['ml'].T,common_set['mm_seq'],device,128).float()\n",
    "\n",
    "torch.save(common_set,'./common_set.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "#train test\n",
    "pos_x,pos_y=np.where(md==1)\n",
    "pos_xy=np.concatenate([pos_x[:,None],pos_y[:,None]],axis=1) #(23337, 2)\n",
    "pos_xy=pos_xy[np.random.permutation(pos_xy.shape[0]),:]\n",
    "train_pos_xy=pos_xy[:int(pos_xy.shape[0]*train_ratio),:]\n",
    "test_pos_xy=pos_xy[int(pos_xy.shape[0]*train_ratio):,:]\n",
    "\n",
    "neg_x,neg_y=np.where(md==0)\n",
    "neg_xy=np.concatenate([neg_x[:,None],neg_y[:,None]],axis=1) #(2562528, 2)\n",
    "neg_xy=neg_xy[np.random.permutation(neg_xy.shape[0]),:]\n",
    "train_neg_xy=neg_xy[:int(neg_xy.shape[0]*train_ratio),:]\n",
    "test_neg_xy=neg_xy[int(neg_xy.shape[0]*train_ratio):,:]\n",
    "\n",
    "train_xy=np.concatenate([train_pos_xy,train_neg_xy],axis=0)\n",
    "train_label=np.concatenate([np.ones(train_pos_xy.shape[0]),np.zeros(train_neg_xy.shape[0])],axis=0)\n",
    "train_rd=np.random.permutation(train_xy.shape[0])\n",
    "train_xy,train_label=train_xy[train_rd,:],train_label[train_rd]\n",
    "test_xy=np.concatenate([test_pos_xy,test_neg_xy],axis=0)\n",
    "test_label=np.concatenate([np.ones(test_pos_xy.shape[0]),np.zeros(test_neg_xy.shape[0])],axis=0)\n",
    "\n",
    "kf = KFold(n_splits=kfolds, shuffle=True, random_state=1)\n",
    "train_idx, valid_idx = [], []\n",
    "for train_index, valid_index in kf.split(train_xy):\n",
    "    train_idx.append(train_index)\n",
    "    valid_idx.append(valid_index)\n",
    "\n",
    "test_md=np.zeros(md.shape)\n",
    "test_md[train_pos_xy[:,0],train_pos_xy[:,1]]=1\n",
    "test_set['edge']=torch.tensor(test_xy).long()\n",
    "test_set['label']=torch.tensor(test_label).long()\n",
    "test_set['md']=torch.tensor(test_md).long()\n",
    "test_set['mm_mdG']=GIP_sim(test_set['md']).float()\n",
    "test_set['dd_dmG']=GIP_sim(test_set['md'].T).float()\n",
    "test_set['mm_mdF']=Functional_sim(test_set['md'],common_set['dd_sem'],device,64).float()\n",
    "test_set['dd_dmF']=Functional_sim(test_set['md'].T,common_set['mm_seq'],device,64).float()\n",
    "\n",
    "torch.save(test_set,'./test_set.pkl')\n",
    "print('test_set saved')\n",
    "\n",
    "for k in range(kfolds):\n",
    "    xy_train,xy_valid=train_xy[train_idx[k],:],train_xy[valid_idx[k],:]\n",
    "    label_train,label_valid=train_label[train_idx[k]],train_label[valid_idx[k]]\n",
    "    train_md=np.zeros(md.shape)\n",
    "    train_md[xy_train[:,0],xy_train[:,1]]=label_train\n",
    "    train_set['edge_train_%d'%k]=torch.tensor(xy_train).long()\n",
    "    train_set['label_train_%d'%k]=torch.tensor(label_train).long()\n",
    "    train_set['edge_valid_%d'%k]=torch.tensor(xy_valid).long()\n",
    "    train_set['label_valid_%d'%k]=torch.tensor(label_valid).long()\n",
    "    train_set['md_%d'%k]=torch.tensor(train_md).long()\n",
    "    train_set['mm_mdG_%d'%k]=GIP_sim(train_set['md_%d'%k]).float()\n",
    "    train_set['dd_dmG_%d'%k]=GIP_sim(train_set['md_%d'%k].T).float()\n",
    "    train_set['mm_mdF_%d'%k]=Functional_sim(train_set['md_%d'%k],common_set['dd_sem'],device,64).float()\n",
    "    train_set['dd_dmF_%d'%k]=Functional_sim(train_set['md_%d'%k].T,common_set['mm_seq'],device,64).float()\n",
    "\n",
    "torch.save(train_set,'./train_set.pkl')\n",
    "print('train_set saved')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
